{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Inspect Classification Training Dataset\n",
    "\n",
    "This notebook is meant to be run after the classification dataset has been created but before training a classifier. Copy this notebook to the same folder as the classification dataset, for example:\n",
    "\n",
    "```\n",
    "MegaDetector/\n",
    "    classification/\n",
    "        BASE_LOGDIR/\n",
    "            classification_ds.csv\n",
    "            inspect_dataset.ipynb  # COPY THIS NOTEBOOK TO HERE\n",
    "            splits.json\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline\n",
    "!pwd"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports and Constants"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.image as mpimg\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "\n",
    "from classification.train_utils import load_splits, plot_img_grid\n",
    "\n",
    "\n",
    "disp_context = pd.option_context(\n",
    "    'display.float_format', '{:0.2f}'.format,\n",
    "    'display.max_rows', 1000)\n",
    "sns.set(style='darkgrid')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "SPLITS = ['train', 'val', 'test']\n",
    "csv_path = 'classification_ds.csv'\n",
    "splits_json_path = 'splits.json'\n",
    "\n",
    "crops_dir = '/path/to/crops'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load dataset and splits files"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the classification dataset CSV into a DataFrame\n",
    "df = pd.read_csv(csv_path, index_col=False, float_precision='high')\n",
    "\n",
    "# Pair every row's dataset with its location as a (dataset, location) tuple\n",
    "df['dataset_location'] = [\n",
    "    (dataset, location)\n",
    "    for dataset, location in zip(df['dataset'], df['location'])\n",
    "]\n",
    "\n",
    "# Sorted list of the distinct labels; later plots use it to order their axes\n",
    "label_order = list(df['label'].unique())\n",
    "label_order.sort()\n",
    "num_labels = len(label_order)\n",
    "\n",
    "display(df.head())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "split_to_locs = load_splits(splits_json_path)\n",
    "\n",
    "# Invert {split: locations} into {location: split}\n",
    "loc_to_split = {\n",
    "    loc: split\n",
    "    for split, locs in split_to_locs.items()\n",
    "    for loc in locs\n",
    "}\n",
    "\n",
    "# Plain indexing (not .get) so a location missing from the splits file raises KeyError\n",
    "df['split'] = df['dataset_location'].map(lambda loc: loc_to_split[loc])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## (Optional) Compare against another set of splits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compare_splits(splits_json_path1: str, splits_json_path2: str,\n",
    "                   name1: str = 'this', name2: str = 'other') -> None:\n",
    "    \"\"\"Print per-split location counts for two splits files and their overlap.\n",
    "\n",
    "    Both files are read with load_splits(). For every split in SPLITS this\n",
    "    prints the number of locations in each file and how many appear in both.\n",
    "\n",
    "    Args:\n",
    "        splits_json_path1: str, path to the first splits.json\n",
    "        splits_json_path2: str, path to the second splits.json\n",
    "        name1: str, name printed for the first splits file\n",
    "        name2: str, name printed for the second splits file\n",
    "    \"\"\"\n",
    "    first = load_splits(splits_json_path1)\n",
    "    second = load_splits(splits_json_path2)\n",
    "\n",
    "    for split in SPLITS:\n",
    "        first_locs = first[split]\n",
    "        print(f'{name1} # of {split} locs:', len(first_locs))\n",
    "        second_locs = second[split]\n",
    "        print(f'{name2} # of {split} locs:', len(second_locs))\n",
    "        shared_locs = first_locs & second_locs\n",
    "        print(f'number of overlap {split} locs:', len(shared_locs))\n",
    "        print('===')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# compare_splits(splits_json_path, '/path/to/other/splits.json')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sample crops from each label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show a few randomly chosen crops (and their dataset rows) for every label\n",
    "NUM_SAMPLES_PER_LABEL = 5\n",
    "\n",
    "for label, group_df in df.groupby('label'):\n",
    "    # DataFrame.sample(n) raises ValueError when n exceeds the group size,\n",
    "    # so cap the request for labels that have only a few crops\n",
    "    num_samples = min(NUM_SAMPLES_PER_LABEL, len(group_df))\n",
    "    group_df = group_df.sample(num_samples)\n",
    "    imgs = [\n",
    "        mpimg.imread(os.path.join(crops_dir, file))\n",
    "        for file in group_df['path']\n",
    "    ]\n",
    "    fig = plot_img_grid(imgs=imgs, row_h=3, col_w=3, ncols=NUM_SAMPLES_PER_LABEL)\n",
    "    print(label)\n",
    "    display(group_df)\n",
    "    display(fig)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Crop confidence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df['confidence'].describe()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Distribution of crop confidence for each label\n",
    "fig, ax = plt.subplots(1, 1, figsize=(10, num_labels/4), tight_layout=True)\n",
    "# order=label_order keeps the label axis in the same order as the label count plot below;\n",
    "# the trailing semicolon hides the Axes repr\n",
    "sns.boxplot(data=df, y='label', x='confidence', order=label_order, ax=ax);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## View distribution of locations and labels by locations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of distinct (dataset, location) pairs in each split.\n",
    "# reindex (rather than [SPLITS]) so a split with no rows shows 0 instead of raising KeyError\n",
    "locs_per_split = (\n",
    "    df.groupby('split')['dataset_location'].nunique()\n",
    "    .reindex(SPLITS, fill_value=0)\n",
    ")\n",
    "locs_per_split.loc['total'] = locs_per_split.sum()\n",
    "display(locs_per_split.to_frame())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of distinct locations per (label, split). reindex keeps every split in SPLITS\n",
    "# as a column even when a split has no rows (plain [SPLITS] would raise KeyError)\n",
    "locations = (\n",
    "    df.groupby(['label', 'split'])['dataset_location'].nunique()\n",
    "    .unstack('split')\n",
    "    .reindex(columns=SPLITS)\n",
    "    .fillna(0)\n",
    "    .astype(int)\n",
    ")\n",
    "locations['total'] = locations.sum(axis=1)\n",
    "\n",
    "# Fraction of each label's locations that falls in each split\n",
    "locations_frac = locations[SPLITS].div(locations['total'], axis=0)\n",
    "locations_all = pd.concat(\n",
    "    [locations_frac, locations], axis=1,\n",
    "    keys=['frac', 'counts'], sort=False)\n",
    "\n",
    "with disp_context:\n",
    "    display(locations_all)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# identify labels with extreme distributions\n",
    "# each entry: (header to print, split, location-count threshold, location-fraction threshold)\n",
    "location_checks = [\n",
    "    ('(test set < 5) or (test set < 10%)', 'test', 5, 0.1),\n",
    "    ('(val set < 5) or (val set < 10%)', 'val', 5, 0.1),\n",
    "    ('(train set < 10) or (train set < 40%)', 'train', 10, 0.4),\n",
    "]\n",
    "\n",
    "with disp_context:\n",
    "    for header, split, min_count, min_frac in location_checks:\n",
    "        print(header)\n",
    "        too_few = locations_all[('counts', split)] < min_count\n",
    "        too_small = locations_all[('frac', split)] < min_frac\n",
    "        # a label is shown when it falls below either threshold\n",
    "        display(locations_all.loc[too_few | too_small])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## View distribution of labels by split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of crops per (label, split). reindex keeps every split in SPLITS as a column\n",
    "# even when a split has no rows (plain [SPLITS] indexing would raise KeyError)\n",
    "labels_dist = (\n",
    "    df.groupby(['label', 'split']).size()\n",
    "    .unstack('split')\n",
    "    .reindex(columns=SPLITS)\n",
    "    .fillna(0)\n",
    "    .astype(int)\n",
    ")\n",
    "\n",
    "# Append a 'total' row holding the number of crops in each split\n",
    "labels_dist_with_total = labels_dist.copy()\n",
    "labels_dist_with_total.loc['total'] = labels_dist.sum(axis=0)\n",
    "\n",
    "# Fraction of each row's crops that falls in each split\n",
    "labels_dist_frac = labels_dist_with_total.div(labels_dist_with_total.sum(axis=1), axis=0)\n",
    "\n",
    "labels_dist_all = pd.concat([labels_dist_frac, labels_dist_with_total], axis=1,\n",
    "                            keys=['frac', 'counts'], sort=False)\n",
    "# Per-row total across splits, added under the 'counts' group\n",
    "labels_dist_all.loc[:, ('counts', 'total')] = labels_dist_all.loc[:, 'counts'].sum(axis=1)\n",
    "\n",
    "with disp_context:\n",
    "    display(labels_dist_all)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# identify labels with extreme distributions\n",
    "def _small_split_mask(split, max_count, max_frac):\n",
    "    \"\"\"Mask of labels_dist_all rows whose count AND fraction for `split` are below the thresholds.\"\"\"\n",
    "    few = labels_dist_all[('counts', split)] < max_count\n",
    "    small = labels_dist_all[('frac', split)] < max_frac\n",
    "    return few & small\n",
    "\n",
    "\n",
    "with disp_context:\n",
    "    print('(test set < 300) and (test set < 9%)')\n",
    "    test_mask = _small_split_mask('test', 300, 0.09)\n",
    "    print(test_mask.sum())\n",
    "\n",
    "    print('(val set < 300) and (val set < 9%)')\n",
    "    val_mask = _small_split_mask('val', 300, 0.09)\n",
    "    print(val_mask.sum())\n",
    "\n",
    "    print('(train set < 1000) and (train set < 40%)')\n",
    "    train_mask = _small_split_mask('train', 1000, 0.4)\n",
    "    print(train_mask.sum())\n",
    "    # display(labels_dist_all.loc[train_mask])\n",
    "\n",
    "    # combined: rows flagged for at least one split\n",
    "    flagged = train_mask | val_mask | test_mask\n",
    "    print(flagged.sum())\n",
    "    display(labels_dist_all.loc[flagged])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# approximate sample weights: total crops / (crops with this label * number of labels)\n",
    "crops_per_label = df['label'].value_counts()\n",
    "num_distinct_labels = df['label'].nunique()\n",
    "sample_weights = len(df) / (crops_per_label * num_distinct_labels)\n",
    "with disp_context:\n",
    "    display(sample_weights)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(1, 1, figsize=(10, num_labels/2), tight_layout=True)\n",
    "sns.countplot(y='label', hue='split', data=df, order=label_order, ax=ax, hue_order=SPLITS)\n",
    "\n",
    "# roughly equivalent to:\n",
    "# labels_dist.plot(kind='barh', figsize=(10, num_labels/2), width=0.8, ax=ax)\n",
    "# ax.invert_yaxis()\n",
    "# ax.grid(axis='y')\n",
    "# ax.set_xlabel('count')\n",
    "\n",
    "# Annotate only the first 1/len(SPLITS) of the bars with their counts.\n",
    "# NOTE(review): assumes seaborn adds the bars of the first hue level ('train') first - confirm\n",
    "# for the installed seaborn version.\n",
    "num_annotated = len(ax.patches) / len(SPLITS)\n",
    "for i, p in enumerate(ax.patches):\n",
    "    if i >= num_annotated:\n",
    "        break\n",
    "    width = p.get_width()\n",
    "    if pd.isna(width):\n",
    "        # a NaN width means there is no bar to label; skip it rather than annotate at NaN\n",
    "        continue\n",
    "    # counts are whole numbers, so print '123' rather than '123.0'\n",
    "    ax.annotate(f'{int(width)}', (width * 1.005, p.get_y() + 0.2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# if necessary, zoom in the x-axis from the plot above\n",
    "# fig, ax = plt.subplots(1, 1, figsize=(10, num_labels/2))\n",
    "# ax = sns.countplot(data=df, y='label', hue='split', order=label_order, ax=ax, hue_order=SPLITS)\n",
    "# ax.set_xlim(0, 5000)\n",
    "# plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Percentage of each split's crops that belongs to each label\n",
    "split_sizes = labels_dist.sum(axis=0)\n",
    "labels_dist_norm = labels_dist.div(split_sizes, axis='columns').mul(100)\n",
    "with disp_context:\n",
    "    display(labels_dist_norm)\n",
    "\n",
    "# Reshape to long format (one row per label/split pair) for seaborn\n",
    "labels_dist_norm = (\n",
    "    labels_dist_norm\n",
    "    .stack('split')\n",
    "    .rename('% of split')\n",
    "    .reset_index()\n",
    ")\n",
    "fig, ax = plt.subplots(1, 1, figsize=(10, num_labels/2), tight_layout=True)\n",
    "ax.set_title('How much each class contributes to each split')\n",
    "sns.barplot(data=labels_dist_norm, y='label', x='% of split', hue='split', ax=ax)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## View distribution of labels by split and dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# which datasets are represented in each split?\n",
    "# reindex (not [SPLITS]) so a split with no rows gives an all-NaN column instead of KeyError\n",
    "with disp_context:\n",
    "    display(\n",
    "        df.groupby(['label', 'split'])['dataset'].unique()\n",
    "        .unstack('split')\n",
    "        .reindex(columns=SPLITS)\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of crops per (label, split, dataset)\n",
    "labels_by_split_ds = df.groupby(['label', 'split', 'dataset']).size().rename('count')\n",
    "with disp_context:\n",
    "    # reindex (not [SPLITS]) so a split with no rows shows zeros instead of raising KeyError\n",
    "    display(\n",
    "        labels_by_split_ds.unstack('split')\n",
    "        .reindex(columns=SPLITS)\n",
    "        .fillna(0)\n",
    "        .astype(int)\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One bar chart per dataset: number of crops per label, coloured by split.\n",
    "# order/hue_order are passed explicitly so labels and split colours match the plots above\n",
    "# instead of depending on the order in which values appear in the data\n",
    "sns.catplot(data=labels_by_split_ds.reset_index(),\n",
    "            x='count', y='label', hue='split', col='dataset',\n",
    "            order=label_order, hue_order=SPLITS,\n",
    "            col_wrap=1, kind='bar', sharex=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## View distribution of labels by split, dataset, and location\n",
    "\n",
    "For each label, dataset, and split:\n",
    "* plot a histogram of the number of crops per location."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of crops per (label, dataset, location, split), as a flat table\n",
    "labels_by_split_ds_loc = df.groupby(['label', 'dataset', 'location', 'split']).size().rename('count').reset_index()\n",
    "with disp_context:\n",
    "    display(labels_by_split_ds_loc.head())\n",
    "labels_by_split_ds_loc['split'] = labels_by_split_ds_loc['split'].astype('category')\n",
    "# hue_order=SPLITS keeps the split colours consistent with the earlier plots; the category\n",
    "# dtype alone would order the hues alphabetically\n",
    "sns.catplot(data=labels_by_split_ds_loc,\n",
    "            col='label', y='dataset', x='count', hue='split',\n",
    "            hue_order=SPLITS,\n",
    "            kind='strip', dodge=True,\n",
    "            col_wrap=5, sharex=False, sharey=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:cameratraps-classifier] *",
   "language": "python",
   "name": "conda-env-cameratraps-classifier-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
